package com.zhisheng.examples.streaming.ml;

import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
import org.apache.flink.streaming.api.functions.co.CoMapFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.api.windowing.time.Time;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.util.Collector;

import java.util.concurrent.TimeUnit;

/**
 * Skeleton for an incremental machine learning algorithm: a pre-computed model is
 * continuously updated with new training data, while the job provides predictions
 * for a separate stream of new input data.
 *
 * <p>This may serve as a base for a number of algorithms, e.g. updating an
 * incremental Alternating Least Squares model while also providing the
 * predictions.
 *
 * <p>This example shows how to use:
 * <ul>
 * <li>Connected streams
 * <li>CoFunctions
 * <li>Tuple data types
 * </ul>
 * blog: http://www.54tianzhisheng.cn/
 * WeChat official account: zhisheng
 */
public class IncrementalLearningSkeleton {

    public static void main(String[] args) throws Exception {

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);

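        // Two finite example sources: trainingData feeds the windowed model builder,
        // newData is the stream the job produces predictions for.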
        DataStream<Integer> trainingData = env.addSource(new FiniteTrainingDataSource());
        DataStream<Integer> newData = env.addSource(new FiniteNewDataSource());

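        // Assign monotonically increasing event timestamps, group the training data into
        // 5-second event-time windows and emit one partial model (Double[]) per window.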
        DataStream<Double[]> model = trainingData
                .assignTimestampsAndWatermarks(new LinearTimestamp())
                .timeWindowAll(Time.of(5000, TimeUnit.MILLISECONDS))
                .apply(new PartialModelBuilder());

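        // Connect the prediction stream with the model stream: the Predictor keeps the
        // latest model and emits a prediction for every newData element.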
        newData.connect(model).map(new Predictor()).print();

        env.execute("Streaming Incremental Learning");
    }

    /**
     * Feeds the newData stream. By default it emits the Integer 1 fifty times,
     * pausing briefly between elements.
     */
    public static class FiniteNewDataSource implements SourceFunction<Integer> {
        private static final long serialVersionUID = 1L;
        private int counter;

        @Override
        public void run(SourceContext<Integer> ctx) throws Exception {
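            // Short initial delay before the new data starts flowing.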
            Thread.sleep(15);
            while (counter < 50) {
                ctx.collect(getNewData());
            }
        }

        @Override
        public void cancel() {
            // No cleanup needed
        }

        private Integer getNewData() throws InterruptedException {
            Thread.sleep(5);
            counter++;
            return 1;
        }
    }

    /**
     * Feeds training data to the partial model builder. By default it emits the
     * Integer 1 in a loop until a fixed number of elements (8200) has been produced.
     */
    public static class FiniteTrainingDataSource implements SourceFunction<Integer> {
        private static final long serialVersionUID = 1L;
        private int counter = 0;

        @Override
        public void run(SourceContext<Integer> collector) throws Exception {
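            // Emit a fixed number of training elements as fast as possible.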
            while (counter < 8200) {
                collector.collect(getTrainingData());
            }
        }

        @Override
        public void cancel() {
            // No cleanup needed
        }

        private Integer getTrainingData() throws InterruptedException {
            counter++;
            return 1;
        }
    }

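    /**
     * Assigns every element a timestamp 10 ms after the previous one and emits a
     * punctuated watermark right behind each element, so event time advances steadily.
     */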
    private static class LinearTimestamp implements AssignerWithPunctuatedWatermarks<Integer> {
        private static final long serialVersionUID = 1L;

        private long counter = 0L;

        @Override
        public long extractTimestamp(Integer element, long previousElementTimestamp) {
            return counter += 10L;
        }

        @Override
        public Watermark checkAndGetNextWatermark(Integer lastElement, long extractedTimestamp) {
            return new Watermark(counter - 1);
        }
    }

    /**
     * Builds up-to-date partial models on new training data.
     */
    public static class PartialModelBuilder implements AllWindowFunction<Integer, Double[], TimeWindow> {
        private static final long serialVersionUID = 1L;

        protected Double[] buildPartialModel(Iterable<Integer> values) {
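            // Placeholder model: a real implementation would derive the model
            // parameters from the window's values (e.g. an incremental ALS update).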
            return new Double[]{1.};
        }

        @Override
        public void apply(TimeWindow window, Iterable<Integer> values, Collector<Double[]> out) throws Exception {
            out.collect(buildPartialModel(values));
        }
    }

    /**
     * Produces a prediction for every newData element, using the model produced in
     * batch processing together with the up-to-date partial model.
     * <p>
     * By default it emits the Integer 0 for every newData element and the Integer 1
     * for every model update.
     * </p>
     */
    public static class Predictor implements CoMapFunction<Integer, Double[], Integer> {
        private static final long serialVersionUID = 1L;

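        // Latest models seen so far: batchModel stands in for a model computed offline,
        // partialModel for the most recent windowed update.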
        Double[] batchModel = null;
        Double[] partialModel = null;

        @Override
        public Integer map1(Integer value) {
            // Return newData
            return predict(value);
        }

        @Override
        public Integer map2(Double[] value) {
            // Update model
            partialModel = value;
            batchModel = getBatchModel();
            return 1;
        }

        protected Double[] getBatchModel() {
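            // Placeholder for a model produced by an offline / batch job.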
            return new Double[]{0.};
        }

        protected Integer predict(Integer inTuple) {
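            // Placeholder prediction: a real implementation would apply
            // batchModel / partialModel to the input element.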
            return 0;
        }

    }

}
